import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger, TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa
from collections import defaultdict

from examples.configs.bcql_configs import BCQL_DEFAULT_CONFIG, BCQLTrainConfig
from osrl.algorithms import BCQL, BCQLTrainer
from osrl.common import TransitionDataset
from osrl.common.exp_util import auto_name, seed_all
import pickle
from osrl.common.cost_functions import env2cost_dict
from osrl.algorithms import EnsembleDynamics, EnsembleDynamicsModel, EnsembleCostModel
from osrl.common.net import StandardScaler, SimpleScaler, termination_fn_common
from copy import deepcopy as dco

# env2dynamics = {
#     "OfflinePointButton1Gymnasium-v0": "logs_new/OfflinePointButton1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-364b/environment_model_safe_onlyTrue_simple_scalerTrue-364b/model",
#     "OfflinePointButton2Gymnasium-v0": "logs_new/OfflinePointButton2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-a4f2/environment_model_safe_onlyTrue_simple_scalerTrue-a4f2/model",
#     "OfflinePointCircle1Gymnasium-v0": "logs_new/OfflinePointCircle1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c4c3/environment_model_safe_onlyTrue_simple_scalerTrue-c4c3/model",
#     "OfflinePointCircle2Gymnasium-v0": "logs_new/OfflinePointCircle2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-2427/environment_model_safe_onlyTrue_simple_scalerTrue-2427/model",
#     "OfflinePointGoal1Gymnasium-v0": "logs_new/OfflinePointGoal1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-412b/environment_model_safe_onlyTrue_simple_scalerTrue-412b/model",
#     "OfflinePointGoal2Gymnasium-v0": "logs_new/OfflinePointGoal2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c822/environment_model_safe_onlyTrue_simple_scalerTrue-c822/model",
#     "OfflinePointPush1Gymnasium-v0": "logs_new/OfflinePointPush1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-9c60/environment_model_safe_onlyTrue_simple_scalerTrue-9c60/model",
#     "OfflinePointPush2Gymnasium-v0": "logs_new/OfflinePointPush2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-73f9/environment_model_safe_onlyTrue_simple_scalerTrue-73f9/model",
#     "OfflineCarButton1Gymnasium-v0": "logs_new/OfflineCarButton1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-06da/environment_model_safe_onlyTrue_simple_scalerTrue-06da/model",
#     "OfflineCarButton2Gymnasium-v0": "logs_new/OfflineCarButton2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-b6f2/environment_model_safe_onlyTrue_simple_scalerTrue-b6f2/model",
#     "OfflineCarCircle1Gymnasium-v0": "logs_new/OfflineCarCircle1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-c7c6/environment_model_safe_onlyTrue_simple_scalerTrue-c7c6/model",
#     "OfflineCarCircle2Gymnasium-v0": "logs_new/OfflineCarCircle2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-fe46/environment_model_safe_onlyTrue_simple_scalerTrue-fe46/model",
#     "OfflineCarGoal1Gymnasium-v0": "logs_new/OfflineCarGoal1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-a583/environment_model_safe_onlyTrue_simple_scalerTrue-a583/model",
#     "OfflineCarGoal2Gymnasium-v0": "logs_new/OfflineCarGoal2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-15f6/environment_model_safe_onlyTrue_simple_scalerTrue-15f6/model",
#     "OfflineCarPush1Gymnasium-v0": "logs_new/OfflineCarPush1Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-78a7/environment_model_safe_onlyTrue_simple_scalerTrue-78a7/model",
#     "OfflineCarPush2Gymnasium-v0": "logs_new/OfflineCarPush2Gymnasium-v0/environment_model_safe_onlyTrue_simple_scalerTrue-3c2a/environment_model_safe_onlyTrue_simple_scalerTrue-3c2a/model",
#     'OfflineAntVelocityGymnasium-v1': "logs_new/OfflineAntVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-1417/environment_model_safe_onlyTrue_simple_scalerTrue-1417/model",          # 16
#     'OfflineHalfCheetahVelocityGymnasium-v1': "logs_new/OfflineHalfCheetahVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-0a2e/environment_model_safe_onlyTrue_simple_scalerTrue-0a2e/model",  # 17
#     'OfflineHopperVelocityGymnasium-v1': "logs_new/OfflineHopperVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-c2d9/environment_model_safe_onlyTrue_simple_scalerTrue-c2d9/model",       # 18
#     'OfflineSwimmerVelocityGymnasium-v1': "logs_new/OfflineSwimmerVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-0692/environment_model_safe_onlyTrue_simple_scalerTrue-0692/model",      # 19
#     'OfflineWalker2dVelocityGymnasium-v1': "logs_new/OfflineWalker2dVelocityGymnasium-v1/environment_model_safe_onlyTrue_simple_scalerTrue-512e/environment_model_safe_onlyTrue_simple_scalerTrue-512e/model",     # 20
#     "PointRobot": "logs_new/PointRobot/environment_model_safe_onlyTrue_simple_scalerTrue-e67a/environment_model_safe_onlyTrue_simple_scalerTrue-e67a/model",
#     "OfflineAntCircle-v0": "logs_new/OfflineAntCircle-v0/environment_model_safe_onlyTrue_simple_scalerTrue-8b6e/environment_model_safe_onlyTrue_simple_scalerTrue-8b6e/model"
# }
# Maps a task name to the path that train() hands to dynamics.load()
# (looked up as env2dynamics[args.task]). Every entry in the table above is
# commented out, so the mapping is empty and that lookup raises KeyError for
# any task until entries are added here.
env2dynamics: Dict[str, str] = {}

def rollout(
    init_obss, rollout_length, trainer, dynamics, cost_func, exp_sigma=0.1, use_unsafe_mask=True,
):
    """Roll the current policy out through the dynamics model and collect transitions.

    Starting from ``init_obss``, the policy (``trainer.model.act``) is queried,
    Gaussian exploration noise is added to its actions, and
    ``dynamics.safe_step`` produces the next observations. Trajectories whose
    ``terminals`` flag is set are dropped from the following steps, and the
    rollout stops early once every trajectory has terminated.

    Args:
        init_obss: Batch of starting observations; the first axis indexes
            trajectories.
        rollout_length: Maximum number of steps to simulate (must be >= 1).
        trainer: Object exposing ``model`` with ``eval()``, ``train()`` and
            ``act(observations) -> (actions, _)``.
        dynamics: Object exposing
            ``safe_step(observations, actions, cost_func)`` returning
            ``(next_observations, rewards, terminals, info)`` where
            ``info["cost"]`` holds the per-transition cost.
        cost_func: Passed through unchanged to ``dynamics.safe_step``.
        exp_sigma: Standard deviation of the Gaussian noise added to the
            policy actions before they are clipped to ``[-1, 1]``.
        use_unsafe_mask: If true, keep only the transitions that belong to
            trajectories which incurred a positive cost at some step.

    Returns:
        A tuple ``(rollout_transitions, info)``. ``rollout_transitions`` maps
        ``"observations"``, ``"next_observations"``, ``"actions"``,
        ``"dones"``, ``"rewards"`` and ``"costs"`` to arrays concatenated over
        all steps (filtered when ``use_unsafe_mask`` is set). ``info`` holds
        ``"num_transitions"`` (rows returned) plus ``"reward_mean"`` and
        ``"cost_mean"`` computed over every simulated transition, i.e. before
        filtering.

    Raises:
        ValueError: If ``rollout_length`` is smaller than 1.
    """
    if rollout_length < 1:
        raise ValueError("rollout_length must be at least 1")

    num_trajs = len(init_obss)
    # Position in the initial batch of every trajectory that is still running.
    # The batch shrinks as trajectories terminate, so this is what ties a row
    # of a later step back to the trajectory it belongs to.
    active_idx = np.arange(num_trajs)
    # Per-trajectory flag: has this trajectory incurred a positive cost yet?
    traj_unsafe = np.zeros(num_trajs, dtype=bool)
    step_traj_idx = []
    reward_chunks, cost_chunks = [], []
    rollout_transitions = defaultdict(list)

    observations = init_obss
    trainer.model.eval()
    try:
        for _ in range(rollout_length):
            actions, _ = trainer.model.act(observations)
            sigma = np.ones_like(actions) * exp_sigma
            actions = np.clip(np.random.normal(actions, sigma), -1.0, 1.0)
            next_observations, rewards, terminals, info = dynamics.safe_step(
                observations, actions, cost_func)
            step_costs = info["cost"]

            rollout_transitions["observations"].append(observations)
            rollout_transitions["next_observations"].append(next_observations)
            rollout_transitions["actions"].append(actions)
            rollout_transitions["dones"].append(terminals)
            rollout_transitions["rewards"].append(rewards)
            rollout_transitions["costs"].append(step_costs)

            # Accumulate as float64 so the statistics match a float64 running array.
            reward_chunks.append(np.asarray(rewards, dtype=np.float64).reshape(-1))
            cost_chunks.append(np.asarray(step_costs, dtype=np.float64).reshape(-1))

            # Assumes one cost value per transition, so the flattened cost
            # lines up with active_idx.
            step_traj_idx.append(active_idx)
            traj_unsafe[active_idx] |= (np.asarray(step_costs) > 0).reshape(-1)

            nonterm_mask = (~terminals).flatten()
            if nonterm_mask.sum() == 0:
                break
            observations = next_observations[nonterm_mask]
            active_idx = active_idx[nonterm_mask]
    finally:
        # Always hand the model back in training mode, even if a step failed.
        trainer.model.train()

    for key in list(rollout_transitions.keys()):
        rollout_transitions[key] = np.concatenate(rollout_transitions[key], axis=0)

    if use_unsafe_mask:
        # One flag per collected transition: was its trajectory ever unsafe?
        unsafe_mask = traj_unsafe[np.concatenate(step_traj_idx, axis=0)]
        for key in list(rollout_transitions.keys()):
            rollout_transitions[key] = rollout_transitions[key][unsafe_mask]

    rewards_arr = np.concatenate(reward_chunks)
    costs_arr = np.concatenate(cost_chunks)
    return rollout_transitions, \
        {"num_transitions": rollout_transitions['observations'].shape[0], "reward_mean": rewards_arr.mean(), "cost_mean": costs_arr.mean()}

@pyrallis.wrap()
def train(args: BCQLTrainConfig):
    """Train BCQL on offline data augmented with rollouts from a loaded dynamics model.

    The flow is: merge the user overrides into the task's default config, set
    up logging and seeding, build and pre-process the offline dataset, build
    the BCQL model/trainer, load the dynamics registered in ``env2dynamics``
    for the task, then alternate between (a) periodically generating synthetic
    transitions with ``rollout`` and adding them to the training dataset and
    (b) one BCQL update per step, with periodic evaluation and checkpointing.

    Args:
        args: Training configuration. Values that differ from the
            ``BCQLTrainConfig`` defaults override the task-specific defaults
            in ``BCQL_DEFAULT_CONFIG[args.task]``.

    Raises:
        KeyError: If ``args.task`` has no entry in ``env2dynamics`` (or in one
            of the other task-keyed tables used below).
    """
    # update config: only the values the user actually changed override the
    # task-specific defaults.
    cfg, old_cfg = asdict(args), asdict(BCQLTrainConfig())
    differing_values = {key: cfg[key] for key in cfg.keys() if cfg[key] != old_cfg[key]}
    cfg = asdict(BCQL_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # Fail fast: the dynamics checkpoint is mandatory for the rollouts below,
    # so verify it is registered before doing any expensive setup.
    if args.task not in env2dynamics:
        raise KeyError(
            f"No dynamics checkpoint registered for task {args.task!r}; "
            "add its path to env2dynamics."
        )

    # setup logger
    default_cfg = asdict(BCQL_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    # Alternative backend:
    # logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # initialize environment: Metadrive tasks are created through the legacy
    # `gym` package, everything else through `gymnasium`.
    if "Metadrive" in args.task:
        env = gym_org.make(args.task)
    else:
        env = gym.make(args.task)

    # pre-process offline dataset
    data = env.get_dataset()
    env.set_target_cost(args.cost_limit)

    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    data = env.pre_process_data(data,
                                args.outliers_percent,
                                args.noise_scale,
                                args.inpaint_ranges,
                                args.epsilon,
                                args.density,
                                cbins=cbins,
                                rbins=rbins,
                                max_npb=max_npb,
                                min_npb=min_npb)
    if args.safe_only:
        # Keep only the zero-cost transitions.
        idx = (data["costs"] == 0)
        for key in data.keys():
            data[key] = data[key][idx]
    else:
        # Supported unsafe fractions and the suffix of their pre-built files.
        unsafe_suffix = {0.5: "50", 0.2: "20", 0.1: "10"}.get(args.unsafe_percent)
        if unsafe_suffix is not None:
            data_path = f"new_data/{args.task}_unsafe_{unsafe_suffix}.pkl"
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point this at trusted, locally generated datasets. The loaded
            # object replaces the pre-processed data above entirely.
            with open(data_path, 'rb') as f:
                data = pickle.load(f)

    if args.conservative_cost_f:
        # Relabel every transition's cost from its next observation.
        cost_func = env2cost_dict[args.task]
        data['costs'] = np.array([cost_func(next_obs) for next_obs in data["next_observations"]])

    # wrapper
    env = wrap_env(
        env=env,
        reward_scale=args.reward_scale,
    )
    env = OfflineEnvWrapper(env)

    # model & optimizer setup
    model = BCQL(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        max_action=env.action_space.high[0],
        a_hidden_sizes=args.a_hidden_sizes,
        c_hidden_sizes=args.c_hidden_sizes,
        vae_hidden_sizes=args.vae_hidden_sizes,
        sample_action_num=args.sample_action_num,
        PID=args.PID,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
        beta=args.beta,
        phi=args.phi,
        num_q=args.num_q,
        num_qc=args.num_qc,
        cost_limit=args.cost_limit,
        episode_len=args.episode_len,
        device=args.device,
    )
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = BCQLTrainer(model,
                          env,
                          logger=logger,
                          actor_lr=args.actor_lr,
                          critic_lr=args.critic_lr,
                          vae_lr=args.vae_lr,
                          reward_scale=args.reward_scale,
                          cost_scale=args.cost_scale,
                          device=args.device)

    # initialize pytorch dataloader
    # `dataset` is the training set and grows with synthetic rollouts;
    # `dataset_real` is never extended and is only used to sample the initial
    # observations of those rollouts.
    dataset = TransitionDataset(dco(data),
                                reward_scale=args.reward_scale,
                                cost_scale=args.cost_scale)

    dataset_real = TransitionDataset(dco(data),
                                     reward_scale=args.reward_scale,
                                     cost_scale=args.cost_scale)

    dataset.seed(args.seed)
    dataset_real.seed(args.seed)

    dynamics_model = EnsembleDynamicsModel(
        obs_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        hidden_dims=args.dynamics_hidden_dims,
        num_ensemble=args.num_ensemble,
        num_elites=args.num_elites,
        weight_decays=args.dynamic_weight_decays,
        with_cost=args.with_cost,
        device=args.device
    )
    # No separate cost model or LR schedulers are used in this script.
    cost_model = None
    dynamics_optim = torch.optim.Adam(
        dynamics_model.parameters(),
        lr=args.dynamics_lr
    )
    cost_model_optim = None
    dynamics_scheduler = None
    cost_model_scheduler = None
    if args.simple_scaler:
        scaler = SimpleScaler()
    else:
        scaler = StandardScaler()
    termination_fn = termination_fn_common
    dynamics = EnsembleDynamics(
        dynamics_model,
        cost_model,
        dynamics_optim,
        cost_model_optim,
        scaler,
        termination_fn,
        use_scheduler=args.use_scheduler,
        dynamics_scheduler=dynamics_scheduler,
        cost_model_scheduler=cost_model_scheduler,
        penalty_coef=args.penalty_coef,
        with_cost=args.with_cost,
        use_delta_obs=args.use_delta_obs,
        reward_scale=args.reward_scale,
        cost_scale=args.cost_scale,
        cost_coef=args.cost_coef
    )
    dynamics.load(env2dynamics[args.task])

    # for saving the best
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    # training
    for step in trange(args.update_steps, desc="Training"):

        if step % args.rollout_interval == 0:
            # Optionally anneal the exploration noise linearly to zero over
            # the course of training.
            rollout_std = args.rollout_std
            if args.rollout_std_decay:
                rollout_std = args.rollout_std * (1 - step / args.update_steps)
            for _ in range(args.rollout_epochs):
                init_data = dataset_real.sample(args.rollout_batch_size)
                init_obss = init_data[0]
                rollout_transitions, rollout_info = rollout(
                    init_obss,
                    args.rollout_length,
                    trainer,
                    dynamics,
                    env2cost_dict[args.task],
                    exp_sigma=rollout_std,
                    use_unsafe_mask=args.use_unsafe_mask,
                )
                dataset.add_data(rollout_transitions)

        batch = dataset.sample(args.batch_size)
        observations, next_observations, actions, rewards, costs, done = [
            torch.tensor(b, dtype=torch.float32).to(args.device) for b in batch
        ]
        trainer.train_one_step(observations, next_observations, actions, rewards, costs,
                               done)

        # evaluation
        if (step + 1) % args.eval_every == 0 or step == args.update_steps - 1:
            ret, cost, length = trainer.evaluate(args.eval_episodes)
            logger.store(tab="eval", Cost=cost, Reward=ret, Length=length)

            # save the current weight
            logger.save_checkpoint()
            # save the best weight: lower cost wins, ties are broken by
            # higher reward.
            if cost < best_cost or (cost == best_cost and ret > best_reward):
                best_cost = cost
                best_reward = ret
                best_idx = step
                logger.save_checkpoint(suffix="best")

            logger.store(tab="train", best_idx=best_idx)
            logger.write(step, display=False)

        else:
            logger.write_without_reset(step)


if __name__ == "__main__":
    # train() is called without arguments even though it declares `args`:
    # the @pyrallis.wrap() decorator is expected to build the BCQLTrainConfig
    # itself (e.g. from the command line) — confirm against the pyrallis docs.
    train()
